import akshare as ak
import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
import matplotlib.pyplot as plt


# 使用 akshare 获取股票数据
stock_code = "600519"  # 贵州茅台股票代码
#stock_df = ak.stock_zh_a_spot_em()
#stock_df = ak.stock_zh_a_daily(symbol=stock_code, start_date="2022-01-01", end_date="2024-01-01")
#stock_df = ak.stock_zh_a_daily(symbol="sz000001", start_date="2022-01-01", end_date="2022-12-31")
#stock_df = ak.stock_zh_a_hist(symbol="sz000001", start_date="2022-01-01", end_date="2022-12-31")
stock_df = ak.stock_zh_a_hist(symbol=stock_code, period="daily", start_date="20170301", end_date='20240528', adjust="hfq")


stock_df = stock_df[stock_df["股票代码"] == stock_code][["日期", "收盘"]].sort_values("日期")
print( stock_df)

# 准备数据
data = stock_df["收盘"].values
data = data.reshape(-1, 1)

# 归一化
scaler = MinMaxScaler(feature_range=(0, 1))
scaled_data = scaler.fit_transform(data)

# 创建数据集
X, y = [], []
for i in range(60, len(scaled_data)):
    X.append(scaled_data[i-60:i, 0])
    y.append(scaled_data[i, 0])
X, y = np.array(X), np.array(y)

X = np.reshape(X, (X.shape[0], X.shape[1], 1))

# 构建 LSTM 模型
model = Sequential()
model.add(LSTM(units=50, return_sequences=True, input_shape=(X.shape[1], 1)))
model.add(LSTM(units=50))
model.add(Dense(1))

model.compile(loss='mean_squared_error', optimizer='adam')
model.fit(X, y, epochs=1, batch_size=1, verbose=2)

# 进行预测
predicted_stock_price = model.predict(X)
predicted_stock_price = scaler.inverse_transform(predicted_stock_price)

# 打印预测结果
print(predicted_stock_price)



plt.figure(figsize=(14, 7))
plt.plot(range(0,len(data)), data, label="茅台股票")
plt.plot( range(0,len(predicted_stock_price)), predicted_stock_price, label="预估茅台股票")
plt.xlabel("Date")
plt.ylabel("Price")
plt.title("Price Forecast")  # 添加一个标题，可选
plt.legend()

# 保存图表到文件
plt.savefig('forecast_plot.png')  # 文件名可以根据需要自定义

# 注意：如果你在无图形界面的环境（如服务器）上运行此代码，确保之前设置了后端为'Agg'
plt.switch_backend('Agg')  # 这行代码在需要时取消注释

# 通常在保存图表后，为了避免在无GUI环境中打开图像窗口，可以显式关闭figure
plt.close()


